/********************************************************
 *  ██████╗  ██████╗████████╗██╗
 * ██╔════╝ ██╔════╝╚══██╔══╝██║
 * ██║  ███╗██║        ██║   ██║
 * ██║   ██║██║        ██║   ██║
 * ╚██████╔╝╚██████╗   ██║   ███████╗
 *  ╚═════╝  ╚═════╝   ╚═╝   ╚══════╝
 * Geophysical Computational Tools & Library (GCTL)
 *
 * Copyright (c) 2023  Yi Zhang (yizhang-geo@zju.edu.cn)
 *
 * GCTL is distributed under a dual licensing scheme. You can redistribute 
 * it and/or modify it under the terms of the GNU Lesser General Public 
 * License as published by the Free Software Foundation, either version 2 
 * of the License, or (at your option) any later version. You should have 
 * received a copy of the GNU Lesser General Public License along with this 
 * program. If not, see <http://www.gnu.org/licenses/>.
 * 
 * If the terms and conditions of the LGPL v.2. would prevent you from using 
 * the GCTL, please consider the option to obtain a commercial license for a 
 * fee. These licenses are offered by the GCTL's original author. As a rule, 
 * licenses are provided "as-is", unlimited in time for a one time fee. Please 
 * send corresponding requests to: yizhang-geo@zju.edu.cn. Please do not forget 
 * to include some description of your company and the realm of its activities. 
 * Also add information on how to contact you by electronic and paper mail.
 ******************************************************/

#ifndef _BOXSORT2D_H
#define _BOXSORT2D_H

#include "../core.h"
#include "heap_sort.h"

namespace gctl
{
    /**
     * @brief   One bucketed entry of a boxes2d grid: a position plus a pointer to its item.
     *
     * @tparam  A   type of the referenced item
     */
    template <typename A>
    struct cargo
    {
        double x;   ///< x coordinate the item was bucketed by
        double y;   ///< y coordinate the item was bucketed by
        A* item;    ///< pointer taken from array<A>::get() in boxes2d::init(); boxes2d never frees it
    };

    /**
     * @brief   Per-box record used by boxes2d::get_by_number() to order boxes by proximity.
     *
     * @tparam  B   item type of the referenced box
     */
    template <typename B>
    struct box_sort_pair
    {
        int id;                             ///< linear box index (column + row * number of columns)
        double near_dist;                   ///< smallest possible distance from the query point to this box
        std::vector< cargo<B> >* box_ptr;   ///< raw pointer to the box contents inside boxes2d
    };

    /**
     * @brief      Uniform 2D box grid that performs distance-based lookup of objects.
     *
     * init() buckets every item into an (x_bnum+2) x (y_bnum+2) grid of equal boxes
     * covering the bounding box of the input coordinates plus a one-box padding ring.
     * Queries hand back raw pointers obtained from the item array given to init()
     * (NOTE(review): presumably pointers into that array; keep it alive while querying).
     *
     * @tparam     T     type of the stored items
     */
    template <typename T>
    class boxes2d
    {
    public:
        /// Creates an empty object; init() must be called before any query.
        boxes2d();
        /// Creates the object and immediately calls init() with the same arguments.
        boxes2d(const array<double> &xs, const array<double> &ys, const array<T> &items, unsigned int x_bnum, unsigned int y_bnum);
        virtual ~boxes2d();

        /// Buckets items[i] located at (xs[i], ys[i]) into an x_bnum by y_bnum interior grid.
        void init(const array<double> &xs, const array<double> &ys, const array<T> &items, unsigned int x_bnum, unsigned int y_bnum);
        /// Collects the items stored in interior box (xid, yid).
        void get_by_index(unsigned int xid, unsigned int yid, std::vector<T*> &cargo_list);
        /// Collects the items of the boxes around (xid, yid), extend_row rows and extend_col columns on each side.
        void get_by_matrix(unsigned int xid, unsigned int yid, unsigned int extend_row, unsigned int extend_col, std::vector<T*> &cargo_list);
        /// Collects the items closer than inrad to (inx, iny); items at exactly inrad are added when on_boundary is true.
        void get_by_radius(double inx, double iny, double inrad, std::vector<T*> &cargo_list, bool on_boundary = false);
        /// Collects up to target_num items nearest to (inx, iny).
        void get_by_number(double inx, double iny, unsigned int target_num, std::vector<T*> &cargo_list);

    protected:
        int xbsize;         ///< number of box columns, padding ring included
        int ybsize;         ///< number of box rows, padding ring included
        int item_size;      ///< number of items given to init()
        double xmin;        ///< lower x bound of the padded grid
        double xmax;        ///< upper x bound of the data extent plus one box
        double ymin;        ///< lower y bound of the padded grid
        double ymax;        ///< upper y bound of the data extent plus one box
        double dx;          ///< box width
        double dy;          ///< box height
        array<std::vector<cargo<T>>> boxes;  ///< box contents, indexed column + row * xbsize
        bool initialized;   ///< set once init() has completed

        // The following members are working state for get_by_number().
        array<box_sort_pair<T>> box_pairs;       ///< per-box proximity records; hold raw pointers into boxes
        heap_sort<box_sort_pair<T>> boxes_sorter; ///< sorts box_pairs by near_dist
        array<double> LocalDist;                  ///< distances of the currently selected items
        bool box_pairs_initalized;                ///< true once box_pairs has been populated
    };

    /**
     * @brief   Builds an empty object; both readiness flags start cleared so every
     *          query throws until init() succeeds.
     */
    template <typename T>
    boxes2d<T>::boxes2d()
        : initialized(false), box_pairs_initalized(false)
    {
    }

    /**
     * @brief   Convenience constructor: clears the flags via the default constructor,
     *          then buckets the items exactly as init() does (and throws what init() throws).
     */
    template <typename T>
    boxes2d<T>::boxes2d(const array<double> &xs, const array<double> &ys, const array<T> &items, unsigned int x_bnum, unsigned int y_bnum)
        : boxes2d<T>()
    {
        this->init(xs, ys, items, x_bnum, y_bnum);
    }

    /**
     * @brief   Releases the storage of every box.
     */
    template <typename T>
    boxes2d<T>::~boxes2d()
    {
        // Nothing was allocated if init() never ran.
        if (boxes.empty()) return;

        for (int b = 0; b < boxes.size(); ++b)
        {
            // Swapping with an empty vector frees the capacity, not just the elements.
            std::vector<cargo<T>> released;
            released.swap(boxes[b]);
        }
    }

    /**
     * @brief   Buckets items[i], located at (xs[i], ys[i]), into a regular grid.
     *
     * The interior grid has x_bnum by y_bnum boxes spanning the data extent, surrounded
     * by a one-box padding ring (xbsize = x_bnum+2, ybsize = y_bnum+2). Every item is
     * placed in an interior box. Calling init() again fully replaces the previous content.
     *
     * @param   xs      x coordinates
     * @param   ys      y coordinates
     * @param   items   items to bucket; boxes keep pointers obtained via items.get(i)
     * @param   x_bnum  number of interior boxes along x (must be > 0)
     * @param   y_bnum  number of interior boxes along y (must be > 0)
     *
     * @throws  domain_error      if any array is empty
     * @throws  invalid_argument  if the sizes differ or a box count is zero
     */
    template <typename T>
    void boxes2d<T>::init(const array<double> &xs, const array<double> &ys, const array<T> &items, unsigned int x_bnum, unsigned int y_bnum)
    {
        if (xs.empty() || ys.empty() || items.empty())
        {
            throw domain_error("Empty arrays. From boxes2d<T>::init(...)");
        }

        if (xs.size() != ys.size() || xs.size() != items.size())
        {
            throw invalid_argument("Arrays' sizes do not match. From boxes2d<T>::init(...)");
        }

        if (x_bnum == 0 || y_bnum == 0)
        {
            throw invalid_argument("Invalid boxes dimensions. From boxes2d<T>::init(...)");
        }

        // Unusable until the rebuild below completes. The get_by_number() cache holds
        // pointers into the old boxes, so it must be rebuilt as well.
        initialized = false;
        box_pairs_initalized = false;

        // One padding column/row on each side of the requested interior grid.
        xbsize = x_bnum+2;
        ybsize = y_bnum+2;
        item_size = items.size();

        // Empty every box so a repeated init() does not keep items from a previous call.
        boxes.resize(xbsize*ybsize);
        for (int b = 0; b < xbsize*ybsize; ++b)
        {
            boxes[b].clear();
        }

        // Data extent, seeded from the first point (no sentinel values needed).
        xmin = xmax = xs[0];
        ymin = ymax = ys[0];
        for (int i = 1; i < item_size; ++i)
        {
            xmin = std::min(xmin, xs[i]);
            xmax = std::max(xmax, xs[i]);
            ymin = std::min(ymin, ys[i]);
            ymax = std::max(ymax, ys[i]);
        }

        dx = (xmax-xmin)/x_bnum;
        dy = (ymax-ymin)/y_bnum;

        // If all points share one x (or y) coordinate, e.g. a single item, the box size is
        // zero and the bucketing below would compute 0/0. Borrow the other axis' box size,
        // or fall back to a unit box.
        if (!(dx > 0.0)) dx = (dy > 0.0) ? dy : 1.0;
        if (!(dy > 0.0)) dy = dx;

        // Pad the extent by one box on every side.
        xmin -= dx; xmax += dx;
        ymin -= dy; ymax += dy;

        // Maps a box coordinate to an interior box index in [1, last]. Points on the
        // upper data edge (and rounding spill-over at the lower edge) would otherwise
        // land in the padding ring, where get_by_index() cannot reach them.
        auto interior_index = [](double v, int last) -> int
        {
            const double f = floor(v);
            if (!(f > 1.0)) return 1;
            if (f >= last) return last;
            return static_cast<int>(f);
        };

        cargo<T> tmp_cargo;
        for (int i = 0; i < item_size; ++i)
        {
            const int N = interior_index((xs[i]-xmin)/dx, xbsize-2);
            const int M = interior_index((ys[i]-ymin)/dy, ybsize-2);

            tmp_cargo.x = xs[i];
            tmp_cargo.y = ys[i];
            tmp_cargo.item = items.get(i);
            boxes[N + M*xbsize].push_back(tmp_cargo);
        }

        initialized = true;
        return;
    }

    /**
     * @brief   Copies the item pointers stored in one interior box into cargo_list.
     *
     * @param   xid         interior column index, 0 <= xid < x_bnum
     * @param   yid         interior row index, 0 <= yid < y_bnum
     * @param   cargo_list  output; cleared first, left empty if the box is empty
     *
     * @throws  runtime_error  if init() has not completed
     * @throws  out_of_range   if (xid, yid) lies outside the interior grid
     */
    template <typename T>
    void boxes2d<T>::get_by_index(unsigned int xid, unsigned int yid, std::vector<T*> &cargo_list)
    {
        if (!initialized)
        {
            throw runtime_error("The boxes2d object is not initialized. From boxes2d<T>::get_by_index(...)");
        }

        // Only interior boxes are addressable; the padding ring is excluded.
        const bool x_outside = xid >= static_cast<unsigned int>(xbsize - 2);
        const bool y_outside = yid >= static_cast<unsigned int>(ybsize - 2);
        if (x_outside || y_outside)
        {
            throw out_of_range("Invalid index. From boxes2d<T>::get_by_index(...)");
        }

        cargo_list.clear();

        // +1 on both axes skips the padding column/row at index 0.
        const int box_id = (xid + 1) + (yid + 1)*xbsize;
        const std::vector<cargo<T>> &box = boxes[box_id];

        cargo_list.reserve(box.size());
        for (const cargo<T> &entry : box)
        {
            cargo_list.push_back(entry.item);
        }
    }

    /**
     * @brief   Collects the items of a rectangular window of boxes centred on interior box (xid, yid).
     *
     * (xid, yid) uses the same interior indexing as get_by_index(): (0, 0) is the first
     * box inside the padding ring. The window reaches extend_row rows and extend_col
     * columns on each side and is clipped to the grid (padding ring included).
     *
     * @param   xid         interior column index, 0 <= xid < x_bnum
     * @param   yid         interior row index, 0 <= yid < y_bnum
     * @param   extend_row  rows added above and below the centre box
     * @param   extend_col  columns added left and right of the centre box
     * @param   cargo_list  output; cleared first, filled row by row
     *
     * @throws  runtime_error  if init() has not completed
     * @throws  out_of_range   if (xid, yid) lies outside the interior grid
     */
    template <typename T>
    void boxes2d<T>::get_by_matrix(unsigned int xid, unsigned int yid, unsigned int extend_row, unsigned int extend_col, std::vector<T*> &cargo_list)
    {
        if (!initialized)
        {
            throw runtime_error("The boxes2d object is not initialized. From boxes2d<T>::get_by_matrix(...)");
        }

        if (xid >= xbsize-2 || yid >= ybsize-2)
        {
            throw out_of_range("Invalid index. From boxes2d<T>::get_by_matrix(...)");
        }

        cargo_list.clear();

        // Clamps a possibly negative grid index to [0, last].
        auto clamp_index = [](long long v, int last) -> int
        {
            if (v < 0) return 0;
            if (v > last) return last;
            return static_cast<int>(v);
        };

        // Shift the interior index by one for the padding ring, matching get_by_index().
        // Signed 64-bit arithmetic keeps "centre - extend" from wrapping around.
        const long long center_row = static_cast<long long>(yid) + 1;
        const long long center_col = static_cast<long long>(xid) + 1;

        const int start_row = clamp_index(center_row - static_cast<long long>(extend_row), ybsize-1);
        const int end_row   = clamp_index(center_row + static_cast<long long>(extend_row), ybsize-1);
        const int start_col = clamp_index(center_col - static_cast<long long>(extend_col), xbsize-1);
        const int end_col   = clamp_index(center_col + static_cast<long long>(extend_col), xbsize-1);

        for (int i = start_row; i <= end_row; ++i)
        {
            for (int j = start_col; j <= end_col; ++j)
            {
                const std::vector<cargo<T>> &box = boxes[j + xbsize*i];
                for (size_t k = 0; k < box.size(); ++k)
                {
                    cargo_list.push_back(box[k].item);
                }
            }
        }
        return;
    }

    /**
     * @brief   Collects every item whose distance to (inx, iny) is below inrad.
     *
     * @param   inx, iny     query point
     * @param   inrad        search radius (must be > 0)
     * @param   cargo_list   output; cleared first, filled row by row over the scanned boxes
     * @param   on_boundary  also include items at exactly distance inrad
     *
     * @throws  runtime_error     if init() has not completed
     * @throws  invalid_argument  if inrad <= 0 or the circle misses the data extent
     */
    template <typename T>
    void boxes2d<T>::get_by_radius(double inx, double iny, double inrad, std::vector<T*> &cargo_list, bool on_boundary)
    {
        if (!initialized)
        {
            throw runtime_error("The boxes2d object is not initialized. From boxes2d<T>::get_by_radius(...)");
        }

        if (inrad <= 0)
        {
            throw invalid_argument("Invalid search radius. From boxes2d<T>::get_by_radius(...)");
        }

        // xmin+dx .. xmax-dx is the unpadded data extent; the circle must overlap it.
        if (inx+inrad <= xmin+dx || inx-inrad >= xmax-dx || iny+inrad <= ymin+dy || iny-inrad >= ymax-dy)
        {
            throw invalid_argument("Invalid inquire coordinates. From boxes2d<T>::get_by_radius(...)");
        }

        cargo_list.clear();

        // Converts a box coordinate (offset from the grid origin, in boxes) to a box
        // index clamped to [0, last]. The clamp is done in floating point so huge radii
        // cannot overflow int; NaN maps to 0.
        auto box_index = [](double v, int last) -> int
        {
            const double f = floor(v);
            if (!(f > 0.0)) return 0;
            if (f >= last) return last;
            return static_cast<int>(f);
        };

        // Box k spans [xmin + k*dx, xmin + (k+1)*dx) (likewise in y), so the circle's
        // bounding square touches boxes floor((inx -/+ inrad - xmin)/dx). The grid
        // origin must be included here: the previous coverage loop dropped xmin/ymin and
        // could scan too few boxes. One extra box of slack per side absorbs rounding at box edges.
        const int start_col = box_index((inx - inrad - xmin)/dx - 1.0, xbsize-1);
        const int end_col   = box_index((inx + inrad - xmin)/dx + 1.0, xbsize-1);
        const int start_row = box_index((iny - inrad - ymin)/dy - 1.0, ybsize-1);
        const int end_row   = box_index((iny + inrad - ymin)/dy + 1.0, ybsize-1);

        for (int i = start_row; i <= end_row; ++i)
        {
            for (int j = start_col; j <= end_col; ++j)
            {
                const std::vector<cargo<T>> &box = boxes[j + xbsize*i];
                for (size_t k = 0; k < box.size(); ++k)
                {
                    const double ddx = box[k].x - inx;
                    const double ddy = box[k].y - iny;
                    const double dist = sqrt(ddx*ddx + ddy*ddy);

                    if (dist < inrad || (on_boundary && dist == inrad))
                    {
                        cargo_list.push_back(box[k].item);
                    }
                }
            }
        }

        return;
    }

    /**
     * @brief   Collects the target_num items nearest to (inx, iny).
     *
     * Boxes are visited in order of their closest possible distance to the query point.
     * Once target_num items are held, a box is skipped when even its nearest point cannot
     * beat the worst item kept so far.
     *
     * @param   inx, iny    query point
     * @param   target_num  number of items wanted (must be > 0)
     * @param   cargo_list  output; cleared first. Holds exactly target_num items, or every
     *                      item (no nullptr padding) when target_num >= the item count.
     *
     * @throws  runtime_error     if init() has not completed
     * @throws  invalid_argument  if target_num is 0
     */
    template <typename T>
    void boxes2d<T>::get_by_number(double inx, double iny, unsigned int target_num, std::vector<T*> &cargo_list)
    {
        if (!initialized)
        {
            throw runtime_error("The boxes2d object is not initialized. From boxes2d<T>::get_by_number(...)");
        }

        if (target_num == 0)
        {
            throw invalid_argument("Invalid target number. From boxes2d<T>::get_by_number(...)");
        }

        cargo_list.clear();

        const int box_num = xbsize*ybsize;

        // Asking for at least every item: return all of them. The list is sized to the
        // real item count instead of being padded with nullptr entries.
        if (target_num >= static_cast<unsigned int>(item_size))
        {
            cargo_list.reserve(item_size);
            for (int i = 0; i < box_num; ++i)
            {
                for (size_t j = 0; j < boxes[i].size(); ++j)
                {
                    cargo_list.push_back(boxes[i][j].item);
                }
            }
            return;
        }

        // Distance from coordinate v to the closed interval [lo, hi]; 0 when inside.
        auto gap = [](double v, double lo, double hi) -> double
        {
            if (v < lo) return lo - v;
            if (v > hi) return v - hi;
            return 0.0;
        };

        // Rebuild every per-box record on each call. Refreshing near_dist already costs one
        // pass over all boxes, and rebuilding id/box_ptr too means the records never point
        // into storage from an earlier init() or from another (copied-from) object.
        box_pairs.resize(box_num);
        for (int b = 0; b < box_num; ++b)
        {
            const int row = b / xbsize;
            const int col = b % xbsize;
            const double near_x = gap(inx, xmin + col*dx, xmin + (col+1)*dx);
            const double near_y = gap(iny, ymin + row*dy, ymin + (row+1)*dy);

            box_pairs[b].id = b;
            box_pairs[b].near_dist = sqrt(near_x*near_x + near_y*near_y);
            box_pairs[b].box_ptr = boxes.get(b);
        }
        box_pairs_initalized = true;

        // Sort boxes by their closest possible distance.
        boxes_sorter.execute(box_pairs, [](array<box_sort_pair<T>> &a, int l_id, int r_id)->bool{
            return a[l_id].near_dist < a[r_id].near_dist;
        });

        cargo_list.resize(target_num, nullptr);
        LocalDist.resize(target_num);

        unsigned int count = 0;
        double worst = 0.0; // largest distance currently kept (distances are non-negative)
        for (int i = 0; i < box_num; ++i)
        {
            const std::vector<cargo<T>> &box = *box_pairs[i].box_ptr;
            if (box.empty()) continue;

            // With the list full, every item of this box is at least near_dist away and
            // therefore cannot replace anything kept so far.
            if (count == target_num && box_pairs[i].near_dist >= worst) continue;

            for (size_t j = 0; j < box.size(); ++j)
            {
                const double ddx = box[j].x - inx;
                const double ddy = box[j].y - iny;
                double dist = sqrt(ddx*ddx + ddy*ddy);
                T *curr_item = box[j].item;

                // Fill phase: take the first target_num items unconditionally.
                if (count < target_num)
                {
                    LocalDist[count] = dist;
                    cargo_list[count] = curr_item;
                    if (dist > worst) worst = dist;
                    ++count;
                    continue;
                }

                if (dist >= worst) continue;

                // Carry the candidate through the unsorted list, swapping with every larger
                // entry. The value carried out at the end is the old maximum, which is dropped.
                // This is cheaper than keeping the list sorted.
                for (unsigned int k = 0; k < target_num; ++k)
                {
                    if (dist < LocalDist[k])
                    {
                        const double tmp_dist = LocalDist[k];
                        LocalDist[k] = dist;
                        dist = tmp_dist;

                        T *tmp_item = cargo_list[k];
                        cargo_list[k] = curr_item;
                        curr_item = tmp_item;
                    }
                }

                worst = LocalDist[0];
                for (unsigned int k = 1; k < target_num; ++k)
                {
                    if (LocalDist[k] > worst) worst = LocalDist[k];
                }
            }
        }
        return;
    }
}

#endif // _BOXSORT2D_H